{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 线性判决"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn\n",
    "import torch.optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "两分类判决"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0步：loss = 0.693147, W = [0.0, 0.0, 0.0]\n",
      "第10000步：loss = 0.0110425, W = [-1.9234099388122559, 5.306361675262451, -7.200407028198242]\n",
      "第20000步：loss = 0.000326398, W = [-3.5933594703674316, 9.800737380981445, -13.304325103759766]\n",
      "第30000步：loss = 1.08718e-05, W = [-5.200921535491943, 14.169845581054688, -19.308265686035156]\n",
      "第40000步：loss = 4.05311e-07, W = [-6.626287937164307, 18.266353607177734, -25.226253509521484]\n",
      "第50000步：loss = 0, W = [-7.085380554199219, 20.696928024291992, -30.254375457763672]\n",
      "第60000步：loss = 0, W = [-6.967508316040039, 20.697986602783203, -30.83762550354004]\n",
      "第70000步：loss = 0, W = [-6.967508316040039, 20.697986602783203, -30.83762550354004]\n",
      "第80000步：loss = 0, W = [-6.967508316040039, 20.697986602783203, -30.83762550354004]\n",
      "第90000步：loss = 0, W = [-6.967508316040039, 20.697986602783203, -30.83762550354004]\n",
      "第100000步：loss = 0, W = [-6.967508316040039, 20.697986602783203, -30.83762550354004]\n"
     ]
    }
   ],
   "source": [
    "# Binary linear classifier (logistic regression) fitted with full-batch Adam.\n",
    "# Each row of x is one sample; the constant 1. in the last column lets the\n",
    "# third weight play the role of the bias term.\n",
    "x = torch.tensor([\n",
    "    [1., 1., 1.],\n",
    "    [2., 3., 1.],\n",
    "    [3., 5., 1.],\n",
    "    [4., 2., 1.],\n",
    "    [5., 4., 1.],\n",
    "])\n",
    "y = torch.tensor([0., 1., 1., 0., 1.])  # float 0/1 target, one per row of x\n",
    "w = torch.zeros(3, requires_grad=True)  # weights start at zero\n",
    "\n",
    "# BCEWithLogitsLoss applies the sigmoid itself, so it is fed the raw scores.\n",
    "criterion = torch.nn.BCEWithLogitsLoss()\n",
    "optimizer = torch.optim.Adam([w])\n",
    "\n",
    "num_updates = 100000  # number of optimizer steps to take\n",
    "log_every = 10000  # report progress once per this many steps\n",
    "\n",
    "# There is one more forward pass than there are updates: step 0 reports the\n",
    "# loss of the untrained weights and the last step reports the loss reached\n",
    "# after the final update, so no update follows it.\n",
    "for step in range(num_updates + 1):\n",
    "    pred = torch.mv(x, w)\n",
    "    loss = criterion(pred, y)\n",
    "    if step % log_every == 0:\n",
    "        print('第{}步：loss = {:g}, W = {}'.format(step, loss, w.tolist()))\n",
    "    if step < num_updates:\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "多分类判决"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0步：loss = 1.09861, W = tensor([[ 0.,  0.,  0.],\n",
      "        [ 0.,  0.,  0.],\n",
      "        [ 0.,  0.,  0.]])\n",
      "第10000步：loss = 0.0353094, W = tensor([[ 4.6957, -3.4924,  2.6249],\n",
      "        [-5.2741,  5.9926,  0.0755],\n",
      "        [ 7.3483, -7.5582,  0.5850]])\n",
      "第20000步：loss = 0.00158344, W = tensor([[ 11.0893,  -9.3732,   7.8006],\n",
      "        [-11.4465,  12.1859,  -2.4046],\n",
      "        [ 13.6819, -13.8749,   1.3359]])\n",
      "第30000步：loss = 7.13686e-05, W = tensor([[ 17.5170, -15.4430,  13.0037],\n",
      "        [-17.6759,  18.3967,  -4.9340],\n",
      "        [ 19.8627, -20.0549,   1.9357]])\n",
      "第40000步：loss = 3.27682e-06, W = tensor([[ 23.6848, -21.5200,  17.9311],\n",
      "        [-23.8189,  24.5445,  -7.3669],\n",
      "        [ 26.0425, -26.2347,   2.5668]])\n",
      "第50000步：loss = 1.92808e-07, W = tensor([[ 28.2911, -26.7368,  21.3005],\n",
      "        [-29.0993,  29.9215,  -9.0855],\n",
      "        [ 32.0491, -32.2796,   3.4274]])\n",
      "第60000步：loss = 3.2848e-08, W = tensor([[ 29.9852, -29.0168,  22.2868],\n",
      "        [-31.7752,  32.7910,  -9.5884],\n",
      "        [ 36.2498, -36.7162,   4.3130]])\n",
      "第70000步：loss = 1.40462e-08, W = tensor([[ 30.5640, -29.8233,  22.5477],\n",
      "        [-32.9160,  34.0327,  -9.7239],\n",
      "        [ 38.3103, -39.0686,   4.8023]])\n",
      "第80000步：loss = 8.55569e-09, W = tensor([[ 30.8763, -30.2519,  22.6703],\n",
      "        [-33.5671,  34.7396,  -9.7873],\n",
      "        [ 39.5035, -40.4826,   5.0915]])\n",
      "第90000步：loss = 6.07732e-09, W = tensor([[ 31.0852, -30.5376,  22.7469],\n",
      "        [-34.0129,  35.2228,  -9.8268],\n",
      "        [ 40.3235, -41.4690,   5.2905]])\n",
      "第100000步：loss = 4.68727e-09, W = tensor([[ 31.2436, -30.7503,  22.8035],\n",
      "        [-34.3517,  35.5881,  -9.8557],\n",
      "        [ 40.9441, -42.2218,   5.4412]])\n"
     ]
    }
   ],
   "source": [
    "# Three-class linear classifier (softmax regression) fitted with full-batch Adam.\n",
    "# Each row of x is one sample; the constant 1. in the last column lets the\n",
    "# last row of w play the role of the per-class bias.\n",
    "x = torch.tensor([\n",
    "    [1., 1., 1.],\n",
    "    [2., 3., 1.],\n",
    "    [3., 5., 1.],\n",
    "    [4., 2., 1.],\n",
    "    [5., 4., 1.],\n",
    "])\n",
    "y = torch.tensor([0, 2, 1, 0, 2])  # integer class index, one per row of x\n",
    "w = torch.zeros(3, 3, requires_grad=True)  # one column of weights per class\n",
    "\n",
    "# CrossEntropyLoss applies log-softmax itself, so it is fed the raw scores.\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam([w])\n",
    "\n",
    "num_updates = 100000  # number of optimizer steps to take\n",
    "log_every = 10000  # report progress once per this many steps\n",
    "\n",
    "# There is one more forward pass than there are updates: step 0 reports the\n",
    "# loss of the untrained weights and the last step reports the loss reached\n",
    "# after the final update, so no update follows it.\n",
    "for step in range(num_updates + 1):\n",
    "    pred = torch.mm(x, w)  # one row of class scores per sample\n",
    "    loss = criterion(pred, y)\n",
    "    if step % log_every == 0:\n",
    "        print('第{}步：loss = {:g}, W = {}'.format(step, loss, w))\n",
    "    if step < num_updates:\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
